keras

您所在的位置:网站首页 keras h5文件无法加载 keras

keras

2023-09-15 06:31| 来源: 网络整理| 查看: 265

一、前言

前段时间基于某个项目的需求,需要将自己魔改的yolo框架部署到单片机上(由于是第一次部署,真是一波三折啊~~),其中的一个环节便是要将keras训练好保存的h5文件转换成tflite(我h5保存的是weights)。接下来,我将就这一环节详细说明下,让各位少踩点坑(因为各类包的版本问题,还是有不少坑的)。

首先说明下我使用的库的版本(对于模型转换主要关注下keras和tensorflow的版本就行了)

python - 3.6.12 keras - 2.2.4 tensorflow - 1.13.1

完整的模型转换代码已上传至GitHub:https://github.com/DeepVegChicken/Learning-Yolo_Deployment

二、convert h5 to pd

开始我照着网上的步骤直接用h5转tflite时,发现不是这个报错就是那个报错,心态炸裂。后面看到有人说可以先将h5转成pd格式,然后再转成tflite就好转了,于是就去试了试,发现h5转pd也报错,仔细查阅了资料后发现是转换的姿势有问题。我转换失败的原因有两点:

由于yolo的训练结构中存在着大量的自定义结构,网上的程序基本上都是使用load_model函数来加载模型,而load_model只能加载keras中默认已经存在的网络结构,这就导致了模型加载不起来普遍keras版本的yolo训练完成后都是使用model.save_weights()函数来进行保存的,然而这个函数最终保存下来的只有训练好的权值,并不包含网络的结构(不信那么可以去对比下model.save_weights()和model.save(),肯定是后面这个函数保存的文件较大,因为它还包含了网络的结构),而我们的yolo也是基于model.save_weights(),只保存了训练好的权值,没有保存网络的结构,从而导致模型无法加载。

为了解决这个问题,我们可以建立一个空权值的网络,然后用这个网络加载上保存的对应的权值,这样就能够构成一个完成的h5模型文件了:

h5Model = YOLONet(Input(shape=(320, 320, 3)), 3, 1) h5Model.load_weights(h5Model_path)

既然解决完模型导入这个问题,那么对于h5转pd就基本没问题了,我的h5转pd的代码如下:

# 这两个是我自己模型的文件位置 from Model import * from Utils import * import os.path as osp from keras import backend as K from tensorflow.python.framework import graph_util, graph_io from tensorflow.python.tools import import_pb_to_tensorboard def h5_to_pb(h5_model, output_dir, model_name, out_prefix, log_tensorboard=False): if osp.exists(output_dir) == False: os.mkdir(output_dir) # 清楚自己网络输入和输出节点的名称至关重要,这关乎着后面pd转tflite的成功与否。 # 如果你不清楚你输出节点名字可以重新改下输出的名称 # 修改输出结点名称 -> output_1, output_2 # 由于我的网络只有两个输出,所以就是'output_1'和'output_2' out_nodes = [] for i in range(len(h5_model.outputs)): out_nodes.append(str(out_prefix) + str(i + 1)) tf.identity(h5_model.output[i], str(out_prefix) + str(i + 1)) sess = K.get_session() init_graph = sess.graph.as_graph_def() main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes) graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False) if log_tensorboard: import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir, model_name), output_dir) h5Model_path = 'DataSet/Weights/TrainedWeights_Tiger.h5' pbOutput_path = 'DataSet/Weights/' pbOutput_name = 'TrainedWeights_Tiger.pb' # # 原输入与输出结点名称 # output_node_names = ["input_1:0", "conv2d_36/BiasAdd:0", "conv2d_38/BiasAdd:0"] # 创建网络 h5Model = YOLONet(Input(shape=(320, 320, 3)), 3, 1) # 利用创建的空网络将训练好的权值加载到网络中 h5Model.load_weights(h5Model_path) # h5_to_pb convert h5_to_pb(h5Model, output_dir=pbOutput_path, model_name=pbOutput_name, out_prefix='output_') 三、convert pd to tflite

如果你担心你网络输出的节点名称没有改变,又或是你想看下输入输出节点的名称,那么可以通过下面找个代码实现:

import os import tensorflow as tf def create_graph(model_dir, model_name): with tf.gfile.FastGFile(os.path.join(model_dir, model_name), 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) tf.import_graph_def(graph_def, name='') pdModel_path = 'DataSet/Weights/' pdModel_name = 'TrainedWeights_Tiger.pb' create_graph(pdModel_path, pdModel_name) tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node] for tensor_name in tensor_name_list: print(tensor_name, '\n')

对了,这个查看代码的输入文件是pd格式的,如果你想要h5格式的,你网上再找找,应该都有的。

万事俱备,只欠东风,接下来就容易多了,我的pd转tflite的代码如下:

import tensorflow as tf pbInput_path = r'DataSet/Weights/TrainedWeights_Tiger.pb' tfliteOutput_path = r'DataSet/Weights/TrainedWeights_Tiger.tflite' # 你网络输入节点的名称 input_tensor_name = ['input_1'] # 你网络输出节点的名称,由于我的网络只有两个输出,所以就是'output_1'和'output_2' # 如果你的网络输出不止两个或者只有一个,你相应的增删就行了 output_tensor_name = ['output_1', 'output_2'] # 你网络输入节点的名称以及其对应的输入尺寸,如果你是多输入的网络,那么你直接逗号继续添加就行了 # 如:input_tensor_shape = {'input_1': [1, 320, 320, 3], 'input_2': [1, 640, 640, 3]} input_tensor_shape = {'input_1': [1, 320, 320, 3]} converter = tf.lite.TFLiteConverter.from_frozen_graph( pbInput_path, input_arrays=input_tensor_name, output_arrays=output_tensor_name, input_shapes=input_tensor_shape) tfliteModel = converter.convert() open(tfliteOutput_path, "wb").write(tfliteModel)

最终生成的文件如下: 在这里插入图片描述 可以看到.tflite文件是最小的,其实还可以更小点,就是在pd转tflite中对其进行uint8量化,就在上面代码加几句话就行了,具体的实现那么网上找找看看(我好懒啊,哈哈 ‘_’ )



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3